package com.lianda.state;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReadOnlyBroadcastState;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.tuple.Tuple6;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.datastream.*;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer011;
import org.apache.flink.util.Collector;
import scala.Int;

import java.net.URL;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

/**
 * BroadCastState使用实例：
 * https://blog.csdn.net/wangpei1949/article/details/99698978
 */
@Slf4j
public class BroadcastStateMain {
    /**
     * Entry point. Wires the job together:
     * <ol>
     *   <li>loads parameters from {@code stateTest.properties} on the classpath,</li>
     *   <li>configures checkpointing against a filesystem state backend,</li>
     *   <li>reads JSON events from Kafka and parses them into tuples,</li>
     *   <li>broadcasts a periodically refreshed MySQL user-config snapshot,</li>
     *   <li>connects both streams and enriches events with user info.</li>
     * </ol>
     *
     * @param args unused; all configuration comes from the properties file
     * @throws Exception if configuration loading or job execution fails
     */
    public static void main(String[] args) throws Exception {
        //1. Resolve and parse the properties file from the classpath.
        URL fileUrl = BroadcastStateMain.class.getClassLoader().getResource("stateTest.properties");
        if (fileUrl == null) {
            // Fail fast with a clear message instead of an anonymous NPE below.
            throw new IllegalStateException("stateTest.properties not found on classpath");
        }
        ParameterTool parameterTool = ParameterTool
                .fromPropertiesFile(fileUrl.getPath());

        //Checkpoint configuration.
        String checkpointDirectory = parameterTool.getRequired("checkpointDirectory");
        long checkpointSecondInterval = parameterTool.getLong("checkpointSecondInterval");

        //Kafka event-stream parameters.
        String fromKafkaBootstrapServers = parameterTool.getRequired("fromKafka.bootstrap.servers");
        String fromKafkaGroupID = parameterTool.getRequired("fromKafka.group.id");
        String fromKafkaTopic = parameterTool.getRequired("fromKafka.topic");

        //Config-stream parameters; the user config lives in MySQL.
        String fromMysqlHost = parameterTool.getRequired("fromMysql.host");
        int fromMysqlPort = parameterTool.getInt("fromMysql.port");
        String fromMysqlDB = parameterTool.getRequired("fromMysql.db");
        String fromMysqlUser = parameterTool.getRequired("fromMysql.user");
        String fromMysqlPasswd = parameterTool.getRequired("fromMysql.passwd");
        int fromMysqlSecondInterval = parameterTool.getInt("fromMysql.secondInterval");

        //2. Set up the execution environment with a filesystem state backend.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setStateBackend((StateBackend) new FsStateBackend(checkpointDirectory, true));

        //Exactly-once checkpoints, retained on job cancellation for manual recovery.
        CheckpointConfig checkpointConfig = env.getCheckpointConfig();
        checkpointConfig.setCheckpointInterval(checkpointSecondInterval * 1000);
        checkpointConfig.setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
        checkpointConfig
                .enableExternalizedCheckpoints(
                        CheckpointConfig
                                .ExternalizedCheckpointCleanup
                                .RETAIN_ON_CANCELLATION
                );

        //3. Kafka event stream: consume raw JSON strings.
        Properties kafkaProperties = new Properties();
        kafkaProperties.put("bootstrap.servers", fromKafkaBootstrapServers);
        kafkaProperties.put("group.id", fromKafkaGroupID);

        FlinkKafkaConsumer011<String> kafkaConsumer011 = new FlinkKafkaConsumer011<String>(
                fromKafkaTopic, new SimpleStringSchema(), kafkaProperties
        );

        kafkaConsumer011.setStartFromEarliest();
        DataStream<String> kafkaSource = env.addSource(kafkaConsumer011)
                .name("KafkaSource")
                .uid("source-id-kafka-source");

        //Parse each JSON record into (userID, eventTime, eventType, productID);
        //malformed records are logged and dropped instead of failing the job.
        SingleOutputStreamOperator<
                Tuple4<String, String, String, Integer>
                > eventStream = kafkaSource.process(new ProcessFunction<String, Tuple4<String, String, String, Integer>>() {
            @Override
            public void processElement(String value, Context ctx, Collector<Tuple4<String, String, String, Integer>> out) throws Exception {
                try {
                    JSONObject obj = JSON.parseObject(value);
                    String userID = obj.getString("userID");
                    String eventTime = obj.getString("eventTime");
                    String eventType = obj.getString("eventType");
                    int productID = obj.getIntValue("productID");
                    out.collect(new Tuple4<>(userID, eventTime, eventType, productID));
                } catch (Exception e) {
                    log.warn("异常数据:{}", value, e);
                }
            }
        });

        //4. MySQL config stream: a custom source that periodically reads the
        //user table and emits a snapshot map of userID -> (userName, userAge).
        DataStreamSource<Map<String, Tuple2<String, Integer>>>
                configStream = env.addSource(new MysqlSource(fromMysqlHost, fromMysqlPort, fromMysqlDB, fromMysqlUser,
                fromMysqlPasswd, fromMysqlSecondInterval));

        /*
          (1) Build the MapStateDescriptor.
          It defines the state's name and its key/value types. Here the key is
          Void (a single-entry state) and the value is the whole config map
          Map<String, Tuple2<String, Integer>>.
          NOTE: the value type must be java.lang.Integer (matching the config
          stream's element type), not scala.Int as in the original code.
         */
        MapStateDescriptor<Void, Map<String, Tuple2<String, Integer>>>
                configDescriptor =
                new MapStateDescriptor<>("config",
                        Types.VOID,
                        Types.MAP(
                                Types.STRING,
                                Types.TUPLE(Types.STRING, Types.INT)
                        )
                );

        /*
          (2) Broadcast the config stream, producing a BroadcastStream.
         */
        BroadcastStream<Map<String, Tuple2<String, Integer>>>
                broadcastConfigStream = configStream.broadcast(configDescriptor);

        //5. Connect the event stream with the broadcast config stream.
        BroadcastConnectedStream<Tuple4<String, String, String, Integer>, Map<String, Tuple2<String, Integer>>>
                connectedStream = eventStream.connect(broadcastConfigStream);

        //6. Process the connected stream: enrich events using the broadcast config.
        SingleOutputStreamOperator<Tuple6<String, String, String, Integer, String, Integer>>
                resultStream = connectedStream.process(new CustomBroadcastProcessFunction());

        //7. Emit the enriched results.
        resultStream.print();

        //8. Build the JobGraph and start execution.
        env.execute("BroadCast state");

    }

    /**
     * Custom BroadcastProcessFunction that joins the broadcast user config
     * with the event stream. An event is forwarded only when its userID
     * appears in the config, in which case it is enriched with the user's
     * name and age; all other events are silently dropped.
     *
     * Type parameters:
     * Tuple4&lt;String, String, String, Integer&gt; — event-stream element type;
     * Map&lt;String, Tuple2&lt;String, Integer&gt;&gt; — config-stream element type;
     * Tuple6&lt;String, String, String, Integer, String, Integer&gt; — output type.
     */
    public static class CustomBroadcastProcessFunction
            extends BroadcastProcessFunction<
            Tuple4<String, String, String, Integer>,
            Map<String, Tuple2<String, Integer>>,
            Tuple6<String, String, String, Integer, String, Integer>> {

        /** Descriptor of the broadcast state; must match the one used when broadcasting. */
        private final MapStateDescriptor<Void, Map<String, Tuple2<String, Integer>>> configDescriptor =
                new MapStateDescriptor<>(
                        "config",
                        Types.VOID,
                        Types.MAP(
                                Types.STRING,
                                Types.TUPLE(Types.STRING, Types.INT)
                        )
                );

        /**
         * Handles one event from the non-broadcast side: reads the current
         * config snapshot from the broadcast state and, when the event's user
         * is known, emits the event extended with (userName, userAge).
         *
         * @param event one element of the event stream
         * @param ctx   read-only context giving access to the broadcast state
         * @param out   collector for zero or one enriched record
         * @throws Exception if state access fails
         */
        @Override
        public void processElement(Tuple4<String, String, String, Integer> event, ReadOnlyContext ctx, Collector<Tuple6<String, String, String, Integer, String, Integer>> out) throws Exception {
            ReadOnlyBroadcastState<Void, Map<String, Tuple2<String, Integer>>> state =
                    ctx.getBroadcastState(configDescriptor);

            // The whole config lives under the single null key.
            Map<String, Tuple2<String, Integer>> userConfig = state.get(null);
            if (userConfig == null) {
                return; // no config received yet — drop the event
            }

            Tuple2<String, Integer> userInfo = userConfig.get(event.f0);
            if (userInfo == null) {
                return; // user not configured — drop the event
            }

            out.collect(new Tuple6<>(event.f0, event.f1, event.f2, event.f3, userInfo.f0, userInfo.f1));
        }

        /**
         * Handles one element of the broadcast (config) side: replaces the
         * stored snapshot with the freshly received one.
         *
         * @param newConfig the latest config map from the broadcast stream
         * @param ctx       context giving read-write access to the broadcast state
         * @param out       collector (unused; config updates emit nothing)
         * @throws Exception if state access fails
         */
        @Override
        public void processBroadcastElement(Map<String, Tuple2<String, Integer>> newConfig, Context ctx, Collector<Tuple6<String, String, String, Integer, String, Integer>> out) throws Exception {
            BroadcastState<Void, Map<String, Tuple2<String, Integer>>> state =
                    ctx.getBroadcastState(configDescriptor);

            // Drop the previous snapshot, then store the new one under null.
            state.clear();
            state.put(null, newConfig);
        }
    }

    /**
     * Custom MySQL source that polls the {@code user_info} table every
     * {@code secondInterval} seconds and emits a full snapshot as a map of
     * userID -> (userName, userAge).
     */
    public static class MysqlSource extends RichSourceFunction<Map<String, Tuple2<String, Integer>>> {

        private final String host;
        private final Integer port;
        private final String db;
        private final String user;
        private final String passwd;
        private final Integer secondInterval;

        // volatile: written by cancel() on a different thread than run().
        private volatile boolean isRunning = true;

        private Connection connection;
        private PreparedStatement preparedStatement;


        /**
         * @param host           MySQL host
         * @param port           MySQL port
         * @param db             database name
         * @param user           database user
         * @param passwd         database password
         * @param secondInterval polling period in seconds
         */
        public MysqlSource(String host, Integer port, String db, String user, String passwd, Integer secondInterval) {
            this.host = host;
            this.port = port;
            this.db = db;
            this.user = user;
            this.passwd = passwd;
            this.secondInterval = secondInterval;
        }

        /**
         * Opens the JDBC connection and prepares the snapshot query once;
         * the statement is reused on every poll.
         *
         * @param parameters Flink runtime configuration
         * @throws Exception if the driver cannot be loaded or the connection fails
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            // MySQL Connector/J 8.x driver class.
            Class.forName("com.mysql.cj.jdbc.Driver");
            connection = DriverManager.getConnection("jdbc:mysql://" + host + ":" + port + "/" + db + "?useUnicode=true&characterEncoding=UTF-8", user, passwd);
            String sql = "select userID,userName,userAge from user_info";
            preparedStatement = connection.prepareStatement(sql);
        }

        /**
         * Releases JDBC resources. The statement is closed before its
         * connection (child before parent), and the finally block guarantees
         * the connection is closed even if closing the statement throws.
         *
         * @throws Exception if closing either resource fails
         */
        @Override
        public void close() throws Exception {
            super.close();
            try {
                if (preparedStatement != null) {
                    preparedStatement.close();
                }
            } finally {
                if (connection != null) {
                    connection.close();
                }
            }
        }

        /**
         * Polling loop: every {@code secondInterval} seconds, reads all rows of
         * user_info into a map and emits it downstream. Stops when cancel()
         * flips {@code isRunning} or the thread is interrupted.
         *
         * @param ctx Flink source context used to emit snapshots
         * @throws Exception never thrown directly; errors are logged
         */
        @Override
        public void run(SourceContext<Map<String, Tuple2<String, Integer>>> ctx) throws Exception {
            try {
                while (isRunning) {
                    Map<String, Tuple2<String, Integer>> output = new HashMap<>();
                    // try-with-resources: the ResultSet was previously leaked on every poll.
                    try (ResultSet resultSet = preparedStatement.executeQuery()) {
                        while (resultSet.next()) {
                            String userID = resultSet.getString("userID");
                            String userName = resultSet.getString("userName");
                            int userAge = resultSet.getInt("userAge");
                            output.put(userID, new Tuple2<>(userName, userAge));
                        }
                    }

                    ctx.collect(output);
                    // Sleep between polls; 1000L avoids int overflow for large intervals.
                    Thread.sleep(1000L * secondInterval);
                }
            } catch (InterruptedException ie) {
                // Restore the interrupt flag so the framework sees the cancellation.
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                log.error("从Mysql获取配置异常...", e);
            }
        }

        /** Signals the polling loop to stop; called by Flink on cancellation. */
        @Override
        public void cancel() {
            isRunning = false;
        }
    }
}
